"""Attribute extractor module for extracting attributes from Bedrock trace data."""

from __future__ import annotations

import json
import logging
import uuid
from typing import Any, Dict, List, Optional

from opentelemetry.util.types import AttributeValue

from openinference.instrumentation import (
    Message,
    TokenCount,
    ToolCall,
    ToolCallFunction,
    get_input_attributes,
    get_llm_attributes,
    get_llm_input_message_attributes,
    get_llm_invocation_parameter_attributes,
    get_llm_model_name_attributes,
    get_llm_output_message_attributes,
    get_llm_token_count_attributes,
    get_output_attributes,
    get_span_kind_attributes,
    get_tool_attributes,
)
from openinference.instrumentation.bedrock.utils.json_utils import (
    fix_loose_json_string,
    safe_json_loads,
)
from openinference.semconv.trace import (
    DocumentAttributes,
    OpenInferenceLLMProviderValues,
    OpenInferenceMimeTypeValues,
    OpenInferenceSpanKindValues,
    SpanAttributes,
)

logger = logging.getLogger(__name__)


class AttributeExtractor:
    """
    Extracts attributes from Bedrock trace data.

    This class provides methods to extract and process attributes from various types of
    trace data generated by Amazon Bedrock services. It handles different types of inputs
    and outputs, including model invocations, tool calls, knowledge base lookups, and more.
    """

    @classmethod
    def get_messages_object(cls, input_text: str) -> Any:
        """
        Parse input text into a list of Message objects.

        Args:
            input_text (str): The input text to parse.

        Returns:
            list[Message]: A list of parsed Message objects.
        """
        messages = list()
        try:
            input_messages = safe_json_loads(input_text)
            if system_message := input_messages.get("system"):
                messages.append(Message(content=system_message, role="system"))

            for message in input_messages.get("messages", []):
                role = message.get("role", "")
                if content := message.get("content"):
                    parsed_contents = fix_loose_json_string(content) or [content]
                    for parsed_content in parsed_contents:
                        message_content = content
                        if isinstance(parsed_content, dict):
                            if parsed_content_type := parsed_content.get("type"):
                                message_content = parsed_content.get(parsed_content_type, "")
                        messages.append(Message(content=message_content, role=role))
        except Exception:
            return [Message(content=input_text, role="assistant")]
        return messages

    @classmethod
    def get_attributes_from_message(cls, message: Dict[str, Any], role: str) -> Optional[Message]:
        """
        Extract attributes from a message dictionary.

        Args:
            message (dict[str, Any]): The message dictionary.
            role (str): The role of the message.

        Returns:
            Message | None: A Message object if attributes can be extracted, None otherwise.
        """
        if message.get("type") == "text":
            return Message(content=message.get("text", ""), role=role)
        if message.get("type") == "tool_use":
            tool_call_function = ToolCallFunction(
                name=message.get("name", ""), arguments=message.get("input", {})
            )
            tool_calls = [ToolCall(id=message.get("id", ""), function=tool_call_function)]
            return Message(tool_call_id=message.get("id", ""), role="tool", tool_calls=tool_calls)
        return None

    @classmethod
    def get_output_messages(cls, model_output: dict[str, Any]) -> Any:
        """
        Extract output messages from model output.

        Args:
            model_output (dict[str, Any]): The model output dictionary.

        Returns:
            list[Message] | None: A list of Message objects if messages can be extracted,
            None otherwise.
        """
        messages = list()
        if raw_response := model_output.get("rawResponse"):
            if output_text := raw_response.get("content"):
                try:
                    data = json.loads(str(output_text))
                    for content in data.get("content") or []:
                        if message := cls.get_attributes_from_message(
                            content, content.get("role", "assistant")
                        ):
                            messages.append(message)
                except Exception:
                    messages.append(Message(content=str(output_text), role="assistant"))
        return messages

    @classmethod
    def get_attributes_from_model_invocation_input(
        cls, model_invocation_input: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract attributes from model invocation input.

        This method processes the model invocation input to extract relevant attributes
        such as the model name, invocation parameters, and input messages. It combines
        these attributes with LLM-specific attributes and span kind attributes.

        Args:
            model_invocation_input (dict[str, Any]): The model invocation input dictionary.

        Returns:
            dict[str, Any]: A dictionary of extracted attributes.
        """
        llm_attributes = {}

        # Get input text
        input_text = ""
        if model_invocation_input and "text" in model_invocation_input:
            input_text = model_invocation_input["text"]

        # Get model name and invocation parameters
        if model_name := AttributeExtractor.get_model_name(model_invocation_input or {}, {}):
            llm_attributes["model_name"] = model_name

        if invocation_parameters := AttributeExtractor.get_invocation_parameters(
            model_invocation_input or {}, {}
        ):
            llm_attributes["invocation_parameters"] = invocation_parameters

        # Get input and output messages
        llm_attributes["input_messages"] = AttributeExtractor.get_messages_object(input_text)

        # Set attributes
        return {
            **get_llm_attributes(**llm_attributes, provider=OpenInferenceLLMProviderValues.AWS),  # type: ignore
            **get_span_kind_attributes(OpenInferenceSpanKindValues.LLM),
            **get_input_attributes(input_text),
        }

    @classmethod
    def get_attributes_from_model_invocation_output(
        cls, model_invocation_output: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract attributes from model invocation output.

        This method processes the model invocation output to extract relevant attributes
        such as the model name, invocation parameters, output messages, and token counts.
        It combines these attributes with LLM-specific attributes and output attributes.

        Args:
            model_invocation_output (dict[str, Any]): The model invocation output dictionary.

        Returns:
            dict[str, Any]: A dictionary of extracted attributes.
        """
        llm_attributes = {}
        if model_name := AttributeExtractor.get_model_name({}, model_invocation_output or {}):
            llm_attributes["model_name"] = model_name

        if invocation_parameters := AttributeExtractor.get_invocation_parameters(
            {}, model_invocation_output or {}
        ):
            llm_attributes["invocation_parameters"] = invocation_parameters

        # Get input and output messages
        llm_attributes["output_messages"] = AttributeExtractor.get_output_messages(
            model_invocation_output or {}
        )

        # Set attributes
        request_attributes = {
            **get_llm_attributes(**llm_attributes, provider=OpenInferenceLLMProviderValues.AWS),  # type: ignore
            **get_llm_token_count_attributes(
                AttributeExtractor.get_token_counts(model_invocation_output or {})
            ),
        }
        # Set output value
        if output_value := AttributeExtractor.get_output_value(model_invocation_output or {}):
            request_attributes = {**request_attributes, **get_output_attributes(output_value)}
        return request_attributes

    @classmethod
    def get_attributes_from_code_interpreter_input(
        cls, code_input: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract attributes from code interpreter input.

        Args:
            code_input (dict[str, Any]): The code interpreter input dictionary.

        Returns:
            dict[str, Any]: A dictionary of extracted attributes.
        """
        tool_call_function = ToolCallFunction(
            name="code_interpreter",
            arguments={"code": code_input.get("code", ""), "files": code_input.get("files", "")},
        )
        tool_calls = [ToolCall(id="default", function=tool_call_function)]
        messages = [Message(tool_call_id="default", role="tool", tool_calls=tool_calls)]
        name = "code_interpreter"
        description = "Executes code and returns results"
        parameters = json.dumps({"code": {"type": "string", "description": "Code to execute"}})
        metadata = {
            "invocation_type": "code_execution",
            "execution_context": code_input.get("context", {}),
        }
        return {
            **get_input_attributes(code_input.get("code", "")),
            **get_span_kind_attributes(OpenInferenceSpanKindValues.TOOL),
            **get_llm_input_message_attributes(messages),
            **get_tool_attributes(name=name, description=description, parameters=parameters),
            **{"metadata": metadata},
        }

    @classmethod
    def get_attributes_from_knowledge_base_lookup_input(
        cls, kb_data: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract attributes from knowledge base lookup input.

        Args:
            kb_data (dict[str, Any]): The knowledge base lookup input dictionary.

        Returns:
            dict[str, Any]: A dictionary of extracted attributes.
        """
        metadata = {
            "invocation_type": "knowledge_base_lookup",
            "knowledge_base_id": kb_data.get("knowledgeBaseId"),
        }
        return {
            **get_input_attributes(kb_data.get("text", "")),
            **get_span_kind_attributes(OpenInferenceSpanKindValues.RETRIEVER),
            **{"metadata": metadata},
        }

    @classmethod
    def get_attributes_from_action_group_invocation_input(
        cls, action_input: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract attributes from action group invocation input.

        Args:
            action_input (dict[str, Any]): The action group invocation input dictionary.

        Returns:
            dict[str, Any]: A dictionary of extracted attributes.
        """
        name = action_input.get("function", "")
        tool_call_function = ToolCallFunction(
            name=name, arguments=action_input.get("parameters", {})
        )
        tool_calls = [ToolCall(id="default", function=tool_call_function)]
        messages = [Message(tool_call_id="default", role="tool", tool_calls=tool_calls)]
        description = action_input.get("description", "")
        parameters = json.dumps(action_input.get("parameters", []))
        llm_invocation_parameters = {
            "invocation_type": "action_group_invocation",
            "action_group_name": action_input.get("actionGroupName"),
            "execution_type": action_input.get("executionType"),
        }
        if invocation_id := action_input.get("invocationId"):
            llm_invocation_parameters["invocation_id"] = invocation_id
        if verb := action_input.get("verb"):
            llm_invocation_parameters["verb"] = verb
        if api_path := action_input.get("apiPath"):
            llm_invocation_parameters["api_path"] = api_path
        return {
            **get_span_kind_attributes(OpenInferenceSpanKindValues.TOOL),
            **get_llm_input_message_attributes(messages),
            **get_tool_attributes(name=name, description=description, parameters=parameters),
            **{"metadata": llm_invocation_parameters},
        }

    @classmethod
    def get_metadata_attributes(cls, trace_metadata: dict[str, Any]) -> dict[str, Any]:
        metadata: dict[str, Any] = {}
        if not trace_metadata:
            return metadata
        if client_request_id := trace_metadata.get("clientRequestId"):
            metadata["clientRequestId"] = client_request_id
        if end_time := trace_metadata.get("endTime"):
            metadata["endTime"] = end_time.timestamp() * 1_000_000_000
        if start_time := trace_metadata.get("startTime"):
            metadata["startTime"] = start_time.timestamp() * 1_000_000_000
        if operation_total_time_ms := trace_metadata.get("operationTotalTimeMs"):
            metadata["operationTotalTimeMs"] = operation_total_time_ms
        if total_time_ms := trace_metadata.get("totalTimeMs"):
            metadata["totalTimeMs"] = total_time_ms
        return metadata

    @classmethod
    def get_observation_metadata_attributes(cls, trace_metadata: dict[str, Any]) -> dict[str, Any]:
        metadata: dict[str, Any] = {}
        if client_request_id := trace_metadata.get("clientRequestId"):
            metadata["clientRequestId"] = client_request_id
        if end_time := trace_metadata.get("endTime"):
            metadata["endTime"] = int(end_time.timestamp() * 1_000_000_000)
        if start_time := trace_metadata.get("startTime"):
            metadata["startTime"] = int(start_time.timestamp() * 1_000_000_000)
        if operation_total_time_ms := trace_metadata.get("operationTotalTimeMs"):
            metadata["operationTotalTimeMs"] = operation_total_time_ms
        if total_time_ms := trace_metadata.get("totalTimeMs"):
            metadata["totalTimeMs"] = total_time_ms
        return metadata

    @classmethod
    def get_attributes_from_agent_collaborator_invocation_input(
        cls, collaborator_input: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract attributes from agent collaborator invocation input.

        Args:
            collaborator_input (dict[str, Any]): The agent collaborator invocation input dictionary.

        Returns:
            dict[str, Any]: A dictionary of extracted attributes.
        """
        input_data = collaborator_input.get("input", {})
        input_type = input_data.get("type", "TEXT")

        # Extract content based on input type
        content = ""
        if input_type == "TEXT":
            content = input_data.get("text", "")
        elif input_type == "RETURN_CONTROL":
            if return_control_results := input_data.get("returnControlResults"):
                content = json.dumps(return_control_results)

        # Create message
        messages = [Message(content=content, role="assistant")]

        # Create metadata
        metadata = {
            "invocation_type": "agent_collaborator_invocation",
            "agent_collaborator_name": collaborator_input.get("agentCollaboratorName"),
            "agent_collaborator_alias_arn": collaborator_input.get("agentCollaboratorAliasArn"),
            "input_type": input_type,
        }

        return {
            **get_span_kind_attributes(OpenInferenceSpanKindValues.AGENT),
            **get_input_attributes(content),
            **get_llm_input_message_attributes(messages),
            **{"metadata": metadata},
        }

    @classmethod
    def get_attributes_from_code_interpreter_output(
        cls, code_invocation_output: dict[str, Any]
    ) -> Dict[str, AttributeValue]:
        """
        Extract attributes from code interpreter output.

        Args:
            code_invocation_output (dict[str, Any]): The code interpreter output dictionary.

        Returns:
            Dict[str, AttributeValue]: A dictionary of extracted attributes.
        """
        output_value = None
        files = None

        if output_text := code_invocation_output.get("executionOutput"):
            output_value = output_text
        elif execution_error := code_invocation_output.get("executionError"):
            output_value = execution_error
        elif code_invocation_output.get("executionTimeout"):
            output_value = "Execution Timeout Error"
        elif files := code_invocation_output.get("files"):
            output_value = json.dumps(files)

        content = json.dumps(files) if files else str(output_value) if output_value else ""
        messages = [Message(role="tool", content=content)]
        return {
            **get_output_attributes(output_value),
            **get_llm_output_message_attributes(messages),
        }

    @classmethod
    def get_attributes_from_agent_collaborator_invocation_output(
        cls, collaborator_output: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract attributes from agent collaborator invocation output.

        Args:
            collaborator_output (dict[str, Any]): The agent collaborator invocation
            output dictionary.

        Returns:
            Dict[str, AttributeValue]: A dictionary of extracted attributes.
        """
        output_data = collaborator_output.get("output", {})
        output_type = output_data.get("type", "TEXT")

        # Extract content based on output type
        output_value = ""
        if output_type == "TEXT":
            output_value = output_data.get("text", "")
        elif output_type == "RETURN_CONTROL":
            if return_control_payload := output_data.get("returnControlPayload"):
                output_value = json.dumps(return_control_payload)

        # Create message
        messages = [Message(role="assistant", content=output_value)]

        # Create metadata
        metadata = {
            "agent_collaborator_name": collaborator_output.get("agentCollaboratorName"),
            "agent_collaborator_alias_arn": collaborator_output.get("agentCollaboratorAliasArn"),
            "output_type": output_type,
        }

        return {
            **get_output_attributes(output_value),
            **get_llm_output_message_attributes(messages),
            **{"metadata": metadata},
        }

    @classmethod
    def get_document_attributes(cls, index: int, ref: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract document attributes from a retrieved reference for OpenTelemetry tracing.

        This method processes a single document reference from Bedrock retrieval results
        and extracts relevant attributes including document ID, content, score, and metadata.
        The attributes are formatted according to OpenInference semantic conventions.

        Args:
            index (int): The index position of this document in the retrieval results,
                        used to create unique attribute keys.
            ref (Dict[str, Any]): A document reference dictionary containing:
                - metadata: Document metadata including chunk ID
                - content: Document content with text and type information
                - score: Relevance score for the retrieved document
                - location: Document location information

        Returns:
            Dict[str, Any]: A dictionary of OpenTelemetry attributes with keys formatted as:
                - retrieval.documents.{index}.document.id: Document chunk ID
                - retrieval.documents.{index}.document.content: Document text content
                - retrieval.documents.{index}.document.score: Relevance score
                - retrieval.documents.{index}.document.metadata: JSON-encoded metadata

        Note:
            The method follows OpenInference semantic conventions for document attributes
            and handles missing fields gracefully by using default values.
        """
        attributes = {}
        base_key = f"{RETRIEVAL_DOCUMENTS}.{index}"
        if document_id := ref.get("metadata", {}).get("x-amz-bedrock-kb-chunk-id", ""):
            attributes[f"{base_key}.{DOCUMENT_ID}"] = document_id

        if document_content := ref.get("content", {}).get("text"):
            attributes[f"{base_key}.{DOCUMENT_CONTENT}"] = document_content

        if document_score := ref.get("score", 0.0):
            attributes[f"{base_key}.{DOCUMENT_SCORE}"] = document_score
        metadata = json.dumps(
            {
                "location": ref.get("location", {}),
                "metadata": ref.get("metadata", {}),
                "type": ref.get("content", {}).get("type"),
            }
        )
        attributes[f"{base_key}.{DOCUMENT_METADATA}"] = metadata
        return attributes

    @classmethod
    def get_attributes_from_knowledge_base_lookup_output(
        cls, retrieved_refs: List[Dict[str, Any]]
    ) -> dict[str, AttributeValue]:
        """
        Extract attributes from knowledge base lookup output.

        Args:
            retrieved_refs (list): The documents list.

        Returns:
            Dict[str, AttributeValue]: A dictionary of extracted attributes.
        """
        attributes: Dict[str, Any] = {}
        for i, ref in enumerate(retrieved_refs):
            attributes |= cls.get_document_attributes(i, ref)
        return attributes

    @classmethod
    def get_event_type(cls, trace_data: dict[str, Any]) -> str:
        """
        Identifies the type of trace event from the provided trace data.

        Args:
            trace_data (dict[str, Any]): The trace data containing information
            about the event.

        Returns:
            str: The identified event type if found, otherwise an empty string.
        """
        trace_events = [
            "preProcessingTrace",
            "orchestrationTrace",
            "guardrailTrace",
            "postProcessingTrace",
            "failureTrace",
        ]
        for trace_event in trace_events:
            if trace_event in trace_data:
                return trace_event
        return ""

    @classmethod
    def get_chunk_type(cls, trace_event_data: dict[str, Any]) -> str:
        """
        Identifies the type of trace event from the provided trace data.

        Args:
            trace_event str: The trace event type.
            trace_event_data (dict[str, Any]): The trace data containing information
            about the chunk.

        Returns:
            str: The identified event type if found, otherwise an empty string.
        """
        chunk_types = [
            "modelInvocationInput",
            "modelInvocationOutput",
            "invocationInput",
            "observation",
            "rationale",
        ]
        for chunk_type in chunk_types:
            if chunk_type in trace_event_data:
                return chunk_type
        return ""

    @classmethod
    def get_attributes_from_invocation_input(
        cls, invocation_input: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract attributes from invocation input.

        Args:
            invocation_input (dict[str, Any]): The trace data dictionary.

        Returns:
            dict[str, Any] | None: A dictionary of extracted attributes if available,
            None otherwise.
        """
        if "actionGroupInvocationInput" in invocation_input:
            return (
                cls.get_attributes_from_action_group_invocation_input(
                    invocation_input["actionGroupInvocationInput"]
                )
                or {}
            )
        if "codeInterpreterInvocationInput" in invocation_input:
            return (
                cls.get_attributes_from_code_interpreter_input(
                    invocation_input["codeInterpreterInvocationInput"]
                )
                or {}
            )
        if "knowledgeBaseLookupInput" in invocation_input:
            return (
                cls.get_attributes_from_knowledge_base_lookup_input(
                    invocation_input["knowledgeBaseLookupInput"]
                )
                or {}
            )
        if "agentCollaboratorInvocationInput" in invocation_input:
            return (
                cls.get_attributes_from_agent_collaborator_invocation_input(
                    invocation_input["agentCollaboratorInvocationInput"]
                )
                or {}
            )
        return {}

    @classmethod
    def get_attributes_from_observation(
        cls, observation: dict[str, Any]
    ) -> Dict[str, AttributeValue]:
        """
        Extract attributes from observation data.

        Args:
            observation (dict[str, Any]): The trace data dictionary.

        Returns:
            Dict[str, AttributeValue]: A dictionary of extracted attributes.
        """
        if "actionGroupInvocationOutput" in observation:
            tool_output = observation["actionGroupInvocationOutput"]
            return get_output_attributes(tool_output.get("text", ""))
        if "codeInterpreterInvocationOutput" in observation:
            return cls.get_attributes_from_code_interpreter_output(
                observation["codeInterpreterInvocationOutput"]
            )
        if "knowledgeBaseLookupOutput" in observation:
            return cls.get_attributes_from_knowledge_base_lookup_output(
                observation["knowledgeBaseLookupOutput"].get("retrievedReferences", [])
            )
        if "agentCollaboratorInvocationOutput" in observation:
            return cls.get_attributes_from_agent_collaborator_invocation_output(
                observation["agentCollaboratorInvocationOutput"]
            )
        return {}

    @classmethod
    def get_metadata_from_observation(
        cls, observation: dict[str, Any]
    ) -> Dict[str, AttributeValue]:
        """
        Extract attributes from observation data.

        Args:
            observation (dict[str, Any]): The trace data dictionary.

        Returns:
            Dict[str, AttributeValue]: A dictionary of extracted attributes.
        """
        events = [
            "actionGroupInvocationOutput",
            "codeInterpreterInvocationOutput",
            "knowledgeBaseLookupOutput",
            "agentCollaboratorInvocationOutput",
        ]
        for event in events:
            if event in observation and (event_data := observation[event]):
                return cls.get_metadata_attributes(event_data.get("metadata", {}))
        return {}

    @classmethod
    def get_model_name(
        cls, input_params: dict[str, Any], output_params: dict[str, Any]
    ) -> str | None:
        """
        Get the model name from input or output parameters.

        Args:
            input_params (dict[str, Any]): The input parameters.
            output_params (dict[str, Any]): The output parameters.

        Returns:
            str | None: The model name if found, None otherwise.
        """
        if model_name := input_params.get("foundationModel"):
            return str(model_name)
        if raw_response := output_params.get("rawResponse"):
            if output_text := raw_response.get("content"):
                try:
                    data = json.loads(str(output_text))
                    model = data.get("model")
                    if model is not None:
                        return str(model)
                except Exception as e:
                    logger.debug(str(e))
        return None

    @classmethod
    def get_invocation_parameters(
        cls, input_params: dict[str, Any], output_params: dict[str, Any]
    ) -> str | None:
        """
        Get the invocation parameters from input or output parameters.

        Args:
            input_params (dict[str, Any]): The input parameters.
            output_params (dict[str, Any]): The output parameters.

        Returns:
            str | None: The invocation parameters as a JSON string if found, None otherwise.
        """
        if inference_configuration := input_params.get("inferenceConfiguration"):
            return json.dumps(inference_configuration)
        if inference_configuration := output_params.get("inferenceConfiguration"):
            return json.dumps(inference_configuration)
        return None

    @classmethod
    def get_token_counts(cls, output_params: dict[str, Any]) -> TokenCount | None:
        """
        Get token counts from output parameters.

        Args:
            output_params (dict[str, Any]): The output parameters.

        Returns:
            TokenCount | None: A TokenCount object if token counts are found, None otherwise.
        """
        if not output_params.get("metadata", {}):
            return None
        if usage := output_params.get("metadata", {}).get("usage"):
            completion, prompt, total = 0, 0, 0

            if input_tokens := usage.get("inputTokens"):
                prompt = input_tokens
            if output_tokens := usage.get("outputTokens"):
                completion = output_tokens
            if (input_tokens := usage.get("inputTokens")) and (
                output_tokens := usage.get("outputTokens")
            ):
                total = input_tokens + output_tokens
            return TokenCount(prompt=prompt, completion=completion, total=total)
        return None

    @classmethod
    def get_output_value(cls, output_params: dict[str, Any]) -> str | None:
        """
        Get the output value from output parameters.

        Args:
            output_params (dict[str, Any]): The output parameters.

        Returns:
            str | None: The output value if found, None otherwise.
        """
        if raw_response := output_params.get("rawResponse"):
            if output_text := raw_response.get("content"):
                return str(output_text)

        parsed_response = output_params.get("parsedResponse", {})
        if output_text := parsed_response.get("text"):
            # This block will be executed for Post Processing trace
            return str(output_text)
        if output_text := parsed_response.get("rationale"):
            # This block will be executed for Pre Processing trace
            return str(output_text)
        return None

    @classmethod
    def get_parent_input_attributes_from_invocation_input(
        cls, invocation_input: Dict[str, Any]
    ) -> Any:
        """
        Extract input attributes for the parent span from an invocation input.

        Supports action-group, code-interpreter, knowledge-base-lookup and
        agent-collaborator invocation inputs.

        Args:
            invocation_input (dict[str, Any]): The invocation input dictionary.

        Returns:
            Optional[dict[str, AttributeValue]]: Input attributes if an input
            value could be located, None otherwise.
        """
        # Simple invocation types: (container key, field holding the input text).
        simple_sources = (
            ("actionGroupInvocationInput", "text"),
            ("codeInterpreterInvocationInput", "code"),
            ("knowledgeBaseLookupInput", "text"),
        )
        for container_key, field in simple_sources:
            container = invocation_input.get(container_key, {})
            if container and (value := container.get(field, "")):
                return get_input_attributes(value)

        # Agent-collaborator inputs carry a typed payload: plain text or the
        # results of a returned-control cycle (serialized to JSON).
        collaborator = invocation_input.get("agentCollaboratorInvocationInput", {})
        input_data = collaborator.get("input", {}) if collaborator else {}
        input_type = input_data.get("type") if input_data else None
        if input_type == "TEXT":
            if text_value := input_data.get("text", ""):
                return get_input_attributes(text_value)
        elif input_type == "RETURN_CONTROL":
            if return_control_results := input_data.get("returnControlResults"):
                return get_input_attributes(json.dumps(return_control_results))

        return None

    @classmethod
    def extract_trace_id(cls, trace_data: dict[str, Any]) -> Any:
        """
        Extract a unique trace ID from trace data.

        This method attempts to find a trace ID in various locations within the
        trace data: the main event data, then each of the well-known nested
        sub-structures (model invocation input/output, invocation input,
        observation, rationale). If no trace ID is found anywhere, a random
        UUID string is returned instead.

        Args:
            trace_data (dict[str, Any]): The trace data containing trace information.

        Returns:
            str: The trace ID found in the data, or a freshly generated UUID4
            string if none exists.
        """
        # Use cls (not the class name) for consistency with the other
        # classmethods in this class.
        trace_event = cls.get_event_type(trace_data)
        event_data = trace_data.get(trace_event, {})

        # The trace ID may sit at the top level of the event data...
        if "traceId" in event_data:
            return event_data["traceId"]

        # ...or nested inside one of these known sub-structures; check them
        # in a fixed order and return the first match.
        for nested_key in (
            "modelInvocationInput",
            "modelInvocationOutput",
            "invocationInput",
            "observation",
            "rationale",
        ):
            nested = event_data.get(nested_key, {})
            if "traceId" in nested:
                return nested["traceId"]

        # Generate a unique ID if none found
        return str(uuid.uuid4())

    @classmethod
    def get_attributes_from_guardrail_trace(cls, guardrail_trace: dict[str, Any]) -> dict[str, Any]:
        """
        Extract attributes from guardrail trace data.
        """
        collected: dict[str, Any] = {}

        # Pull metadata-derived fields first (request id and timing info).
        metadata = cls.get_metadata_attributes(guardrail_trace.get("metadata", {}))
        # clientRequestId is only carried over when it is truthy.
        if client_request_id := metadata.get("clientRequestId"):
            collected["clientRequestId"] = client_request_id
        # Timing fields are copied verbatim whenever present in the metadata.
        for timing_key in ("startTime", "endTime", "totalTimeMs"):
            if timing_key in metadata:
                collected[timing_key] = metadata[timing_key]

        # Top-level guardrail fields are copied directly from the trace itself.
        for trace_key in ("action", "inputAssessments", "outputAssessments"):
            if trace_key in guardrail_trace:
                collected[trace_key] = guardrail_trace[trace_key]

        return collected

    @classmethod
    def is_blocked_guardrail(cls, guardrails: List[dict[str, Any]]) -> bool:
        """
        Determine whether an agent invocation was blocked by any intervening guardrails
        """

        for guardrail in guardrails:
            assessments = guardrail.get("inputAssessments", []) + guardrail.get(
                "outputAssessments", []
            )
            for assessment in assessments:
                # Check each of the assessment policy types to see if the guardrail is blocked
                if cls.is_assessment_blocked(assessment, "contentPolicy", ["filters"]):
                    return True
                if cls.is_assessment_blocked(
                    assessment, "sensitiveInformationPolicy", ["piiEntities", "regexes"]
                ):
                    return True
                if cls.is_assessment_blocked(assessment, "topicPolicy", ["topics"]):
                    return True
                if cls.is_assessment_blocked(
                    assessment, "wordPolicy", ["customWords", "managedWordLists"]
                ):
                    return True
        return False

    @classmethod
    def is_assessment_blocked(
        cls, assessment: dict[str, Any], policy_type: str, policy_filters: List[str]
    ) -> bool:
        """
        Parses through guardrail assessment to determine if the action is BLOCKED
        """
        blocked = "BLOCKED"
        policy = assessment.get(policy_type, {})

        filters = []
        for filter_type in policy_filters:
            filters += policy.get(filter_type, [])

        for filter in filters:
            if filter.get("action") == blocked:
                return True
        return False

    @classmethod
    def get_failure_trace_attributes(cls, trace_data: dict[str, Any]) -> dict[str, Any]:
        """
        Build output attributes describing a failure trace.

        Combines the failure code and failure reason (when present) into a
        single output message; returns an empty dict when neither exists.
        """
        message_parts = []
        if failure_code := trace_data.get("failureCode"):
            message_parts.append(f"Failure Code: {failure_code}\n")
        if failure_reason := trace_data.get("failureReason"):
            message_parts.append(f"Failure Reason: {failure_reason}")
        if message_parts:
            return get_output_attributes("".join(message_parts))
        return {}

    @classmethod
    def extract_retrieve_invocation_params(cls, kwargs: dict[str, Any]) -> dict[str, Any]:
        """
        Extract invocation parameters for Bedrock retrieve operations.

        This method processes the keyword arguments from a Bedrock retrieve operation
        and extracts relevant invocation parameters including knowledge base ID,
        pagination tokens, and retrieval configuration settings.

        Args:
            kwargs (dict[str, Any]): Keyword arguments from the retrieve operation,
                                   typically containing knowledgeBaseId, nextToken,
                                   and retrievalConfiguration.

        Returns:
            dict[str, Any]: A dictionary containing extracted invocation parameters:
                - knowledgeBaseId: The ID of the knowledge base being queried
                - next_token: Pagination token for retrieving additional results (if present)
                - retrieval_configuration: Configuration settings for the retrieval (if present)
        """
        invocation_params = {"knowledgeBaseId": kwargs.get("knowledgeBaseId", "")}
        if next_token := kwargs.get("nextToken"):
            invocation_params["next_token"] = next_token
        if retrieval_configuration := kwargs.get("retrievalConfiguration", {}):
            invocation_params["retrieval_configuration"] = retrieval_configuration
        return invocation_params

    @classmethod
    def get_model_name_for_rag(cls, kwargs: Dict[str, Any]) -> str:
        """
        Extract the model name/ARN from RAG (Retrieve and Generate) operation parameters.

        This method determines the model being used in a RAG operation by examining
        the retrieveAndGenerateConfiguration. It handles both knowledge base and
        external sources configurations to extract the appropriate model ARN.

        Args:
            kwargs (Dict[str, Any]): Keyword arguments from the RAG operation,
                                   containing retrieveAndGenerateConfiguration with
                                   either knowledgeBaseConfiguration or
                                   externalSourcesConfiguration.

        Returns:
            str: The model ARN/name being used for the RAG operation. Returns empty
                 string if no model ARN is found in the configuration.

        Note:
            The method checks the configuration type and extracts the model ARN from
            the appropriate configuration section (knowledge base or external sources).
        """
        retrieve_and_generate_config = kwargs.get("retrieveAndGenerateConfiguration", {})
        if retrieve_and_generate_config.get("type") == "KNOWLEDGE_BASE":
            return str(
                retrieve_and_generate_config.get("knowledgeBaseConfiguration", {}).get(
                    "modelArn", ""
                )
            )
        return str(
            retrieve_and_generate_config.get("externalSourcesConfiguration", {}).get("modelArn", "")
        )

    @classmethod
    def extract_rag_invocation_params(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract invocation parameters for RAG (Retrieve and Generate) operations.

        This method processes the keyword arguments from a Bedrock retrieve_and_generate
        operation and extracts relevant invocation parameters including configuration
        settings and session information.

        Args:
            kwargs (Dict[str, Any]): Keyword arguments from the retrieve_and_generate
                                   operation, containing configuration and session parameters.

        Returns:
            Dict[str, Any]: A dictionary containing extracted invocation parameters:
                - retrieveAndGenerateConfiguration: Configuration for the RAG operation (if present)
                - sessionConfiguration: Session-specific configuration settings (if present)
                - sessionId: Unique identifier for the session (if present)

        Note:
            Only parameters that are present in the input kwargs are included in the
            returned dictionary, allowing for flexible parameter handling.
        """
        invocation_params = {}
        if rag_configuration := kwargs.get("retrieveAndGenerateConfiguration"):
            invocation_params["retrieveAndGenerateConfiguration"] = rag_configuration
        if session_configuration := kwargs.get("sessionConfiguration"):
            invocation_params["sessionConfiguration"] = session_configuration
        if session_id := kwargs.get("sessionId"):
            invocation_params["sessionId"] = session_id
        return invocation_params

    @classmethod
    def extract_bedrock_retrieve_input_attributes(cls, kwargs: dict[str, Any]) -> dict[str, Any]:
        """
        Build OpenTelemetry input attributes for a Bedrock retrieve operation.

        Combines the query text, the RETRIEVER span kind, and the extracted
        invocation parameters into a single attribute dictionary following
        OpenInference semantic conventions.

        Args:
            kwargs (dict[str, Any]): Keyword arguments from the retrieve
                operation (retrievalQuery, knowledgeBaseId, nextToken,
                retrievalConfiguration).

        Returns:
            dict[str, Any]: The merged attribute dictionary.
        """
        query_text = kwargs.get("retrievalQuery", {}).get("text", "")
        invocation_params = cls.extract_retrieve_invocation_params(kwargs)

        attributes: dict[str, Any] = {}
        attributes.update(
            get_input_attributes(query_text, mime_type=OpenInferenceMimeTypeValues.TEXT)
        )
        # Retrieval operations are categorized as RETRIEVER spans.
        attributes.update(get_span_kind_attributes(OpenInferenceSpanKindValues.RETRIEVER))
        attributes.update(get_llm_invocation_parameter_attributes(invocation_params))
        return attributes

    @classmethod
    def extract_bedrock_rag_input_attributes(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build OpenTelemetry input attributes for a Bedrock RAG operation.

        Combines the model name, the user input text, the RETRIEVER span kind,
        and the extracted invocation parameters into a single attribute
        dictionary following OpenInference semantic conventions.

        Args:
            kwargs (Dict[str, Any]): Keyword arguments from the
                retrieve_and_generate operation (input,
                retrieveAndGenerateConfiguration, sessionConfiguration,
                sessionId).

        Returns:
            Dict[str, Any]: The merged attribute dictionary.

        Note:
            RAG operations are classified as RETRIEVER spans since they
            combine retrieval and generation in a single operation.
        """
        user_text = kwargs.get("input", {}).get("text", "")
        model_name = cls.get_model_name_for_rag(kwargs)
        invocation_params = cls.extract_rag_invocation_params(kwargs)

        attributes: Dict[str, Any] = {}
        attributes.update(get_llm_model_name_attributes(model_name))
        attributes.update(
            get_input_attributes(user_text, mime_type=OpenInferenceMimeTypeValues.TEXT)
        )
        attributes.update(get_span_kind_attributes(OpenInferenceSpanKindValues.RETRIEVER))
        attributes.update(get_llm_invocation_parameter_attributes(invocation_params))
        return attributes

    @classmethod
    def extract_bedrock_retrieve_response_attributes(
        cls, response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Build OpenTelemetry document attributes from a retrieve response.

        Args:
            response (Dict[str, Any]): Response from the retrieve operation;
                the "retrievalResults" list holds the retrieved documents.

        Returns:
            Dict[str, Any]: Document attributes for each retrieved result,
            formatted per OpenInference semantic conventions.

        Note:
            Delegates to get_attributes_from_knowledge_base_lookup_output so
            document formatting stays consistent across retrieval paths.
        """
        retrieved_documents = response.get("retrievalResults", [])
        return cls.get_attributes_from_knowledge_base_lookup_output(retrieved_documents)

    @classmethod
    def extract_bedrock_rag_response_attributes(cls, response: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build OpenTelemetry attributes from a retrieve_and_generate response.

        Extracts document attributes for every retrieved reference across all
        citations (indexed sequentially) and merges in the generated output
        text.

        Args:
            response (Dict[str, Any]): Response from the retrieve_and_generate
                operation, containing "citations" (each with
                "retrievedReferences") and "output" with the generated text.

        Returns:
            Dict[str, Any]: Document attributes for each cited reference plus
            output attributes for the generated text.
        """
        # Flatten every retrieved reference across all citations, preserving
        # citation order; "citations" may be present but None.
        retrieved_documents = [
            document
            for citation in (response.get("citations", []) or [])
            for document in citation.get("retrievedReferences", [])
        ]

        attributes: Dict[str, Any] = {}
        # Sequential indices give each document a distinct attribute key prefix.
        for position, document in enumerate(retrieved_documents):
            attributes.update(cls.get_document_attributes(position, document))

        generated_text = response.get("output", {}).get("text")
        attributes.update(get_output_attributes(generated_text))
        return attributes


# Constants
# Short aliases for the OpenInference semantic-convention attribute keys used
# when emitting retrieved-document attributes on spans.
DOCUMENT_ID = DocumentAttributes.DOCUMENT_ID
DOCUMENT_CONTENT = DocumentAttributes.DOCUMENT_CONTENT
DOCUMENT_SCORE = DocumentAttributes.DOCUMENT_SCORE
DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
